from PyQt5.QtCore import pyqtSignal, Qt
from PyQt5.QtGui import QPainter, QMouseEvent, QKeyEvent

from structures.basic_structure import BasicStructure
from trajectories.trajectory_factory import TrajectoryFactory
from view.layers.coordinate_layer import CoordinateLayer
from view.layers.record_layer import RecordLayer
from view.layers.shape_layer import Shape, ShapeLayer
from view.layers.simulator_layer import SimulatorLayer
from view.layers.trajectory_layer import TrajectoryLayer
from view.widget.base_overlay_widget import BaseOverlay
from utils.math_tool import *


class LegWidget(BaseOverlay):
    """Overlay widget that draws a ``BasicStructure`` through a stack of layers.

    The widget owns the structure, a ``TrajectoryFactory`` bound to it and one
    instance of each layer (coordinate, trajectory, simulator, record, shape).
    Mouse events are routed to the shape layer, the record layer or the
    structure itself depending on the current drawing mode.
    """

    # Emitted on every handled mouse event with the cursor position relative to
    # the canvas origin, each coordinate passed through ``inverse_ratio``.
    mouse_tracking_event = pyqtSignal(float, float)

    # Emitted by ``update_and_notify`` after a repaint has been scheduled.
    structure_changed_event = pyqtSignal()

    def __init__(self, parent):
        """Create the structure, the trajectory factory and all layers."""
        super().__init__(parent)

        self._is_free_drawing = False

        self._shape_drawing = Shape.NONE

        # Receive mouse-move events even when no button is pressed.
        self.setMouseTracking(True)

        # Canvas origin offset in widget pixels; recomputed on every paint.
        self.canvas_offset_x = 0
        self.canvas_offset_y = 0

        self._ls = BasicStructure()

        self.trajectory_factory = TrajectoryFactory(self._ls)

        # Insertion order is the render order used by ``draw_content``.
        self.layers = {
            CoordinateLayer.LAYER_NAME: CoordinateLayer(self._ls),
            TrajectoryLayer.LAYER_NAME: TrajectoryLayer(self._ls),
            SimulatorLayer.LAYER_NAME: SimulatorLayer(self._ls),
            RecordLayer.LAYER_NAME: RecordLayer(self._ls),
            ShapeLayer.LAYER_NAME: ShapeLayer(self._ls),
        }

    def set_shape_drawing_finished(self, on_shape_finished):
        """Connect ``on_shape_finished`` to the shape layer's finished signal."""
        shape_layer: ShapeLayer = self.layers[ShapeLayer.LAYER_NAME]
        shape_layer.shape_finished_signal.connect(on_shape_finished)

    def set_shape_drawing(self, shape: Shape):
        """Select the shape to draw, forward it to the shape layer and repaint."""
        self._shape_drawing = shape
        shape_layer: ShapeLayer = self.layers[ShapeLayer.LAYER_NAME]
        shape_layer.set_shape(shape)
        self.update()

    def get_shape_drawing(self):
        """Return the shape currently selected for drawing.

        The value is whatever was last passed to ``set_shape_drawing``;
        ``Shape.NONE`` (the initial value) means no shape is being drawn.

        NOTE(review): the original note listed -1 = none, 0 = line,
        1 = circle, 2 = triangle, 3 = rectangle -- confirm against ``Shape``.
        """
        return self._shape_drawing

    def set_link_visible(self, link_name, state):
        """Show the named link when ``state`` is non-zero, hide it otherwise."""
        self._ls.links_visible[link_name] = state != 0
        self.update()

    def set_coordinate_visible(self, visible):
        """Show the coordinate layer when ``visible`` is non-zero, else hide it."""
        self.layers[CoordinateLayer.LAYER_NAME].visible = visible != 0
        self.update()

    def set_end_mode(self, end_mode_left):
        """Forward the end mode to the structure's ``set_left_mode`` and repaint."""
        self._ls.set_left_mode(end_mode_left)
        self.update()

    def is_free_drawing(self) -> bool:
        """Return the free-drawing flag set by ``set_free_drawing``."""
        return self._is_free_drawing

    def set_free_drawing(self, is_free_drawing, record_points_list=None):
        """Switch free-drawing mode and (re)initialise the record layer.

        :param is_free_drawing: new value of the free-drawing flag.
        :param record_points_list: passed unchanged to
            ``RecordLayer.init_record``; ``None`` by default.
        """
        self._is_free_drawing = is_free_drawing
        record_layer: RecordLayer = self.layers[RecordLayer.LAYER_NAME]
        record_layer.init_record(record_points_list)
        # The record layer is only rendered while free drawing is active, so
        # schedule a repaint like the other mode setters do.
        self.update()

    def update_last_record_traj_horizontal(self):
        """Do nothing; kept so existing callers keep working.

        The forwarding call to
        ``RecordLayer.update_last_record_traj_horizontal`` is disabled.
        """

    def get_final_free_drawing_points_list(self, for_save=False) -> list:
        """Return the recorded point groups, skipping empty groups.

        :param for_save: when true, every point is converted with the
            structure's ``inverse_ratio`` before being returned.
        :return: list of non-empty point groups from the record layer.
        """
        record_layer: RecordLayer = self.layers[RecordLayer.LAYER_NAME]
        final_points_list = []
        for key_points in record_layer.record_points_list:
            # ``len`` rather than truthiness: a group may be an array-like
            # whose truth value is ambiguous.
            if len(key_points) == 0:
                continue
            if for_save:
                key_points = [self._ls.inverse_ratio(point) for point in key_points]
            final_points_list.append(key_points)

        return final_points_list

    def check_free_drawing_and_run(self) -> bool:
        """Check the recorded points for reachability and repaint.

        Despite its name this method does not start anything itself; the
        caller decides what to do with the returned flag.

        :return: ``True`` when the record layer reports every point as being
            in the reachable range, ``False`` otherwise.
        """
        record_layer: RecordLayer = self.layers[RecordLayer.LAYER_NAME]
        all_reachable = record_layer.check_all_points_in_reachable_range()
        self.update()
        return bool(all_reachable)

    def set_mode(self, mode):
        """Forward ``mode`` to the structure, repaint and notify listeners."""
        self._ls.set_mode(mode)
        self.update_and_notify()

    def get_structure(self) -> BasicStructure:
        """Return the ``BasicStructure`` owned by this widget."""
        return self._ls

    def reset_links(self):
        """Reset the structure's links, repaint and notify listeners."""
        self._ls.reset_links()
        self.update_and_notify()

    def update_angle_value(self, angle_name, degrees):
        """Update one angle on the structure and repaint immediately."""
        self._ls.update_angle_value(angle_name, degrees)
        self.repaint()

    def update_link_value(self, link_name, value):
        """Update one link value on the structure and repaint immediately."""
        self._ls.update_link_value(link_name, value)
        self.repaint()

    def trajectory_running(self):
        """Return the factory's ``is_any_trajectory_running()`` result."""
        return self.trajectory_factory.is_any_trajectory_running()

    def trajectory_generate(self, track_key, *args):
        """Ask the factory to generate the trajectory identified by ``track_key``.

        ``self.update`` is handed to the factory (presumably as a repaint
        callback -- confirm in ``TrajectoryFactory.generate``), followed by
        ``*args``.
        """
        self.trajectory_factory.generate(track_key, self.update, *args)

    def trajectory_clean(self, track_key=None) -> bool:
        """Clean trajectories through the factory.

        :param track_key: forwarded to ``TrajectoryFactory.clean``.
        :return: ``True`` (after scheduling a repaint) when the factory's
            ``clean`` returned a truthy value, ``False`` otherwise.
        """
        if self.trajectory_factory.clean(track_key):
            self.update()
            return True
        return False

    def get_traj_joints(self) -> list:
        """Return the path angles of all trajectories, rounded with ``round_value(_, 1)``.

        The angles of every trajectory in the structure's trajectory list are
        concatenated in list order.
        """
        rst_joints = []
        for traj in self._ls.get_trajectory_list():
            rst_joints.extend(round_value(joint, 1) for joint in traj.get_path_angles())

        return rst_joints

    def get_traj_points(self, for_save=False) -> list:
        """Return the path points of all trajectories, concatenated in order.

        :param for_save: when true, each trajectory's points are filtered and
            converted: length-1 entries are kept unchanged, longer entries are
            passed through ``inverse_ratio`` and ``round_value(_, 1)``, and
            empty entries are dropped.
        """
        rst_points = []

        for traj in self._ls.get_trajectory_list():
            points = traj.get_path_points()
            if for_save:
                converted = []
                for point in points:
                    size = len(point)
                    if size == 1:
                        # Single-value entries are stored as-is.
                        converted.append(point)
                    elif size > 1:
                        converted.append(round_value(self._ls.inverse_ratio(point), 1))
                points = converted

            rst_points.extend(points)

        return rst_points

    def draw_content(self, qp: QPainter):
        """Paint all visible layers with the origin moved to the canvas offset.

        The record layer is skipped unless free drawing is active.
        """
        self.draw_test_rect(qp)

        qp.save()
        try:
            # Recompute the canvas offset from the current widget size.
            self.canvas_offset_x = int(self.width() / 2)
            self.canvas_offset_y = int(self.height() / 2 - self._ls.forward_ratio(30))
            # Move the painter's origin to the canvas offset.
            qp.translate(self.canvas_offset_x, self.canvas_offset_y)

            for layer in self.layers.values():
                if not layer.visible:
                    continue

                if isinstance(layer, RecordLayer) and not self.is_free_drawing():
                    continue

                layer.render(qp, self)
        finally:
            # Always balance ``save`` so a failing layer cannot leave the
            # painter with a translated coordinate system.
            qp.restore()

    def update_and_notify(self):
        """Schedule a repaint and emit ``structure_changed_event``."""
        self.update()
        self.structure_changed_event.emit()

    def _handle_mouse_event(self, event: QMouseEvent):
        """Dispatch a mouse event to the active drawing mode.

        The position is first made relative to the canvas offset and reported
        through ``mouse_tracking_event``. Shape drawing takes precedence over
        free drawing; when neither is active the event goes to the structure.
        """
        x, y = event.x() - self.canvas_offset_x, event.y() - self.canvas_offset_y
        self.mouse_tracking_event.emit(self._ls.inverse_ratio(x), self._ls.inverse_ratio(y))

        if self.get_shape_drawing() != Shape.NONE:
            shape_layer: ShapeLayer = self.layers[ShapeLayer.LAYER_NAME]
            # Repaint only when the layer reports that it handled the event.
            if shape_layer.handle_drawing_event(event, x, y):
                self.repaint()
            return

        if self.is_free_drawing():
            record_layer: RecordLayer = self.layers[RecordLayer.LAYER_NAME]
            if record_layer.handle_drawing_event(event, x, y):
                self.repaint()
            return

        self._ls.handleMouseEvent(event.type(), event.button(), np.array([x, y]), self.update_and_notify)

    # Press, move and release all share the same dispatcher.
    mousePressEvent = _handle_mouse_event
    mouseMoveEvent = _handle_mouse_event
    mouseReleaseEvent = _handle_mouse_event
